#!/usr/bin/env python3
# Copyright 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import os
import sys

from setuptools import find_packages, setup

def get_platform_flag(argv, default="any"):
    """Return the value given to ``--plat-name`` on the command line.

    Both ``--plat-name VALUE`` and ``--plat-name=VALUE`` are recognized; the
    first occurrence wins. If the option is absent, is the final argument
    with no value after it, or has an empty ``=`` value, *default* is
    returned and setuptools' own option parser is left to report any
    malformed command line.

    Args:
        argv: Command-line arguments to scan (normally ``sys.argv``).
        default: Platform tag used when no usable value is found.

    Returns:
        The platform name string.
    """
    prefix = "--plat-name="
    for index, arg in enumerate(argv):
        if arg == "--plat-name":
            # Guard against a trailing "--plat-name" with nothing after it.
            return argv[index + 1] if index + 1 < len(argv) else default
        if arg.startswith(prefix):
            return arg[len(prefix):] or default
    return default


PLATFORM_FLAG = get_platform_flag(sys.argv)

# The package version is supplied through the environment by the build system.
# An unset *or* blank VERSION is treated as missing, so a misconfigured build
# fails loudly instead of producing a wheel with an empty version.
VERSION = os.environ.get("VERSION", "").strip()
if not VERSION:
    raise RuntimeError("envvar VERSION must be specified")

def normalize_platform_tag(plat_name):
    """Return *plat_name* normalized for use as a wheel platform tag.

    Applies the same normalization upstream ``bdist_wheel`` uses for its own
    ``--plat-name``: lower-case, with ``-``, ``.`` and spaces replaced by
    ``_``. Without it, a value such as ``linux-x86_64`` would put extra
    ``-`` separators into the wheel filename and break tag parsing.
    """
    return plat_name.lower().replace("-", "_").replace(".", "_").replace(" ", "_")


# Locate the upstream bdist_wheel command. setuptools >= 70.1 ships its own
# implementation; older toolchains only have the one in the separate `wheel`
# package, whose `wheel.bdist_wheel` module is deprecated in newer releases.
try:
    from setuptools.command.bdist_wheel import bdist_wheel as _bdist_wheel
except ImportError:
    try:
        from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
    except ImportError:
        _bdist_wheel = None

if _bdist_wheel is not None:

    class bdist_wheel(_bdist_wheel):
        """bdist_wheel that produces a ``py3-none-<PLATFORM_FLAG>`` wheel.

        The wheel is forced to be non-pure (platform-specific) while keeping
        generic Python and ABI tags.
        """

        def finalize_options(self):
            super().finalize_options()
            # Force a platform-specific (non-pure) wheel.
            self.root_is_pure = False

        def get_tag(self):
            # Generic interpreter/ABI tags; only the platform tag is specific.
            return "py3", "none", normalize_platform_tag(PLATFORM_FLAG)

else:
    # Neither setuptools nor wheel provides bdist_wheel, so no custom
    # command class is available to register.
    bdist_wheel = None

# Absolute path of the directory holding this setup script.
# NOTE(review): appears unused in this script; kept in case other tooling
# relies on the name.
this_directory = os.path.abspath(os.path.dirname(__file__))

# (target_dir, [files]) pairs for setup(data_files=...). The "" target dir is
# relative to the installation prefix, so LICENSE.txt lands at its root.
data_files = [("", ["LICENSE.txt"])]

# Non-Python files shipped inside the packages (registered below under the
# "" package_data key, which setuptools applies to every package):
#  * the file named by TRITON_PYBIND (presumably the built pybind extension),
#  * py.typed, the PEP 561 marker declaring that the package supports type
#    checkers (https://peps.python.org/pep-0561/),
#  * .pyi type-hint stubs for the C extension, generated by mypy.
# Fail with a clear message (like the VERSION check) instead of a bare KeyError.
if not os.environ.get("TRITON_PYBIND"):
    raise RuntimeError("envvar TRITON_PYBIND must be specified")

platform_package_data = [
    os.environ["TRITON_PYBIND"],
    "py.typed",
    "_c/__init__.pyi",
    "_c/triton_bindings.pyi",
]

# Optional dependency groups, installable as tritonfrontend[GPU],
# tritonfrontend[test] or tritonfrontend[all].
gpu_extras = ["cupy-cuda12x"]  # CuPy wheel built for CUDA 12.x
test_extras = ["pytest"]
all_extras = [*gpu_extras, *test_extras]  # every optional group combined

# Package metadata and build configuration for the tritonfrontend distribution.
setup(
    name="tritonfrontend",
    version=VERSION,
    author="NVIDIA Inc.",
    author_email="sw-dl-triton@nvidia.com",
    description="Triton Inference Server In-Process Python API",
    license="BSD",
    url="https://developer.nvidia.com/nvidia-triton-inference-server",
    classifiers=[
        "Development Status :: 5 - Production/Stable",
        "Intended Audience :: Developers",
        "Intended Audience :: Science/Research",
        "Intended Audience :: Information Technology",
        "Topic :: Scientific/Engineering",
        "Topic :: Scientific/Engineering :: Image Recognition",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
        "Topic :: Software Development :: Libraries",
        "Topic :: Utilities",
        "License :: OSI Approved :: BSD License",
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.10",
        "Environment :: Console",
        "Natural Language :: English",
        "Operating System :: OS Independent",
    ],
    packages=find_packages(),
    # The "" key applies these files to every package found above.
    package_data={
        "": platform_package_data,
    },
    zip_safe=False,
    # cmdclass values must be Command classes. Registering None would be
    # returned as-is by setuptools' command lookup and break bdist_wheel
    # (and --help-commands), so only register the override when it exists.
    cmdclass={"bdist_wheel": bdist_wheel} if bdist_wheel is not None else {},
    data_files=data_files,
    install_requires=["tritonserver", "pydantic==2.10.6"],
    extras_require={"GPU": gpu_extras, "test": test_extras, "all": all_extras},
)
